Skip to content

[MLIR] Integration tests for lowering vector.contract to SVE FEAT_I8MM #140573

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: users/momchil-velikov/vector-contract-i8mm-transform-dialect
Choose a base branch
from

Conversation

momchil-velikov
Copy link
Collaborator

No description provided.

@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir-sve

Author: Momchil Velikov (momchil-velikov)

Changes

Patch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff

5 Files Affected:

  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999,  1941,   685, -2879 )
+// CHECK: ( -3705,  2952,   987,  -685 )
+// CHECK: (  2565,  4157, -1589,  -357 )
+// CHECK: (  2383, -2252,    32, -1365 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c256 = arith.constant 256 : i32
+  func.call @setArmVLBits(%c256) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46,  -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39, -48, -31, -25, -21],
+                                   [-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49],
+                                   [-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17]]> : vector<8x8xi32>
+  %acc_m = memref.alloca() : memref<8x8xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+  vector.print %acc4 : vector<[4]xi32>
+  vector.print %acc5 : vector<[4]xi32>
+  vector.print %acc6 : vector<[4]xi32>
+  vector.print %acc7 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38],
+                                   [ -3,   9,  43, -30, -32,  39,  41, -39],
+                                   [-13, -21, -25,  27,  47, -36, -11, -11],
+                                   [ -4, -20,  36,  11,  13, -23,  24, -13],
+                                   [-20,  30,  -5,   1,  42, -37, -22,  35],
+                                   [-22,  38,  -4,  44,  25, -31,  23, -39],
+                                   [-45,  -4, -31, -24,  14, -41, -47,  22]]> : vector<8x8xi8>
+
+  %lhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+  %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+  %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+  %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+  %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+  vector.print %lhs4 : vector<8xi8>
+  vector.print %lhs5 : vector<8xi8>
+  vector.print %lhs6 : vector<8xi8>
+  vector.print %lhs7 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-40, -11, -36,  36,  -1,  20,  14, -32],
+                                   [ 46, -45, -48, -46, -24,  31, -36,  22],
+                                   [  2,  36,  45, -29, -37, -49, -20, -35],
+                                   [ -6,  23,  23,  15,  20,   4,  -8,  -2],
+                                   [-35,  -6,  16,  49, -50,   9, -44,  13],
+                                   [ 24,   1,  -4, -44,  41,  15, -43,  44],
+                                   [ 44,   0, -10,  41,  22,  44, -40,   0],
+                                   [-33,  19,  27,  22,  38, -17,  23,  -9]]> : vector<8x8xi8>
+
+  %rhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multilication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+  vector.print %u4 : vector<[4]xi32>
+  vector.print %u5 : vector<[4]xi32>
+  vector.print %u6 : vector<[4]xi32>
+  vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282,  2728,  -410, -1328,   882, -5498,   732 )
+// CHECK: (  1012, -4237,  4154,  2624,  5225, -2338,  2011,  1374 )
+// CHECK: (    -8, -1611,  2905,    -1, -1068, -3155, -2428,   153 )
+// CHECK: (  2034, -1768, -2092,   284,  -792,   -23,   668,  2172 )
+// CHECK: (  -248, -3728,  1214,   555,  -668, -2114, -1794,  2560 )
+// CHECK: ( -1484, -2642,   297,  1551,  -483,  3173,  -576,  2570 )
+// CHECK: (  3098, -7851,  1366,  1892,  -427, -4533,  -819,  4698 )
+// CHECK: (  -135,  1247,   765,  -479,  1245,  3074, -2281,   -23 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175,  82,  99],
+                                   [221,  25, 164,  97, 156, 221, 218, 177],
+                                   [171, 160, 219, 191, 144,  45, 161, 210],
+                                   [223, 165, 123,  99, 108,  86,  37,  92]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: (  -7613,  -8386, -15938,  -6521 )
+// CHECK: (   9468,  18750,   9199,   5764 )
+// CHECK: (  33655,  41064,  48900,  31627 )
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 19, 2025

@llvm/pr-subscribers-mlir

Author: Momchil Velikov (momchil-velikov)

Changes

Patch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140573.diff

5 Files Affected:

  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir (+117)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir (+159)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir (+118)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir (+119)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-usmmla-4x8x4.mlir (+117)
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
new file mode 100644
index 0000000000000..88534dd2aab1e
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-4x8x4.mlir
@@ -0,0 +1,117 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17],
+                                   [-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -1999,  1941,   685, -2879 )
+// CHECK: ( -3705,  2952,   987,  -685 )
+// CHECK: (  2565,  4157, -1589,  -357 )
+// CHECK: (  2383, -2252,    32, -1365 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
new file mode 100644
index 0000000000000..ce57be91fa540
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-smmla-8x8x8-vs2.mlir
@@ -0,0 +1,159 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c256 = arith.constant 256 : i32
+  func.call @setArmVLBits(%c256) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46,  -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39, -48, -31, -25, -21],
+                                   [-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49],
+                                   [-17, -50,  -1,  48, -13,  22,  39,  33],
+                                   [-35, -24,  37, -32,  33,  30, -11, -17]]> : vector<8x8xi32>
+  %acc_m = memref.alloca() : memref<8x8xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<8x8xi32>, memref<8x8xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<8x8xi32> into memref<64xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<64xi32>, vector<[32]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[32]xi32> to vector<8x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc4 = vector.extract %acc[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc5 = vector.extract %acc[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc6 = vector.extract %acc[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %acc7 = vector.extract %acc[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+  vector.print %acc4 : vector<[4]xi32>
+  vector.print %acc5 : vector<[4]xi32>
+  vector.print %acc6 : vector<[4]xi32>
+  vector.print %acc7 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-28,  31,   3, -44, -15, -27,  22,  35],
+                                   [-23,  39,  48,  26, -23,  32, -39, -38],
+                                   [ -3,   9,  43, -30, -32,  39,  41, -39],
+                                   [-13, -21, -25,  27,  47, -36, -11, -11],
+                                   [ -4, -20,  36,  11,  13, -23,  24, -13],
+                                   [-20,  30,  -5,   1,  42, -37, -22,  35],
+                                   [-22,  38,  -4,  44,  25, -31,  23, -39],
+                                   [-45,  -4, -31, -24,  14, -41, -47,  22]]> : vector<8x8xi8>
+
+  %lhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<8x8xi8>, vector<8x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<8x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<8x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<8x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<8x8xi8>
+  %lhs4 = vector.extract %lhs[4] : vector<8xi8> from vector<8x8xi8>
+  %lhs5 = vector.extract %lhs[5] : vector<8xi8> from vector<8x8xi8>
+  %lhs6 = vector.extract %lhs[6] : vector<8xi8> from vector<8x8xi8>
+  %lhs7 = vector.extract %lhs[7] : vector<8xi8> from vector<8x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+  vector.print %lhs4 : vector<8xi8>
+  vector.print %lhs5 : vector<8xi8>
+  vector.print %lhs6 : vector<8xi8>
+  vector.print %lhs7 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[-40, -11, -36,  36,  -1,  20,  14, -32],
+                                   [ 46, -45, -48, -46, -24,  31, -36,  22],
+                                   [  2,  36,  45, -29, -37, -49, -20, -35],
+                                   [ -6,  23,  23,  15,  20,   4,  -8,  -2],
+                                   [-35,  -6,  16,  49, -50,   9, -44,  13],
+                                   [ 24,   1,  -4, -44,  41,  15, -43,  44],
+                                   [ 44,   0, -10,  41,  22,  44, -40,   0],
+                                   [-33,  19,  27,  22,  38, -17,  23,  -9]]> : vector<8x8xi8>
+
+  %rhs_m = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<8x8xi8> into memref<64xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<64xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[ 0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<8x8xi8> to vector<8x8xi32>
+  %1 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<8x8xi32>, vector<[4]x8xi32> into vector<8x[4]xi32>
+
+  // Display the result of the multilication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u4 = vector.extract %2[4] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u5 = vector.extract %2[5] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u6 = vector.extract %2[6] : vector<[4]xi32> from vector<8x[4]xi32>
+  %u7 = vector.extract %2[7] : vector<[4]xi32> from vector<8x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+  vector.print %u4 : vector<[4]xi32>
+  vector.print %u5 : vector<[4]xi32>
+  vector.print %u6 : vector<[4]xi32>
+  vector.print %u7 : vector<[4]xi32>
+
+
+// CHECK: ( -2294, -1282,  2728,  -410, -1328,   882, -5498,   732 )
+// CHECK: (  1012, -4237,  4154,  2624,  5225, -2338,  2011,  1374 )
+// CHECK: (    -8, -1611,  2905,    -1, -1068, -3155, -2428,   153 )
+// CHECK: (  2034, -1768, -2092,   284,  -792,   -23,   668,  2172 )
+// CHECK: (  -248, -3728,  1214,   555,  -668, -2114, -1794,  2560 )
+// CHECK: ( -1484, -2642,   297,  1551,  -483,  3173,  -576,  2570 )
+// CHECK: (  3098, -7851,  1366,  1892,  -427, -4533,  -819,  4698 )
+// CHECK: (  -135,  1247,   765,  -479,  1245,  3074, -2281,   -23 )
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
new file mode 100644
index 0000000000000..f1f311ddb0c18
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-summla-4x8x4.mlir
@@ -0,0 +1,118 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf  --convert-vector-to-llvm='enable-arm-sve enable-arm-i8mm' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm  --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void  --march=aarch64 --mattr="+sve,+i8mm" \
+// DEFINE:    -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(d0, d1, d2) -> (d0, d2)>,
+  affine_map<(d0, d1, d2) -> (d1, d2)>,
+  affine_map<(d0, d1, d2) -> (d0, d1)>
+]
+
+func.func private @setArmVLBits(%bits : i32)
+
+func.func @main() {
+  %c128 = arith.constant 128 : i32
+  func.call @setArmVLBits(%c128) : (i32) -> ()
+
+  %c0 = arith.constant 0 : index
+  %c0_i32 = arith.constant 0 : i32
+  %c0_i8 = arith.constant 0 : i8
+
+// Accumulator test data
+  %acc_cst = arith.constant dense<[[-44,  20,  44, -46],
+                                   [ -8,  25, -34,  26],
+                                   [-20, -36,  -3,  39],
+                                   [-48, -31, -25, -21]]> : vector<4x4xi32>
+  %acc_m = memref.alloca() : memref<4x4xi32>
+  vector.transfer_write %acc_cst, %acc_m[%c0, %c0] : vector<4x4xi32>, memref<4x4xi32>
+
+  %acc_m1 = memref.collapse_shape %acc_m [[0, 1]] : memref<4x4xi32> into memref<16xi32>
+  %acc_flat = vector.transfer_read %acc_m1[%c0], %c0_i32 {in_bounds = [true]} : memref<16xi32>, vector<[16]xi32>
+  %acc = vector.shape_cast %acc_flat : vector<[16]xi32> to vector<4x[4]xi32>
+
+  vector.print str "ACC:\n"
+  %acc0 = vector.extract %acc[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc1 = vector.extract %acc[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc2 = vector.extract %acc[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %acc3 = vector.extract %acc[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %acc0 : vector<[4]xi32>
+  vector.print %acc1 : vector<[4]xi32>
+  vector.print %acc2 : vector<[4]xi32>
+  vector.print %acc3 : vector<[4]xi32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[-35, -27, -36, -31,  23, -34,  -8, -33],
+                                   [-20,  17, -32, -47,  37,  22,  -7, -21],
+                                   [ -7, -35,  20,  -4,  39,  46, -23,  40],
+                                   [ 40,  27,  37,  43,  38,  -6,  37,  49]]> : vector<4x8xi8>
+
+  %lhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %lhs_cst, %lhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+  %lhs = vector.transfer_read %lhs_m[%c0, %c0], %c0_i8 : memref<4x8xi8>, vector<4x8xi8>
+
+  vector.print str "LHS:\n"
+  %lhs0 = vector.extract %lhs[0] : vector<8xi8> from vector<4x8xi8>
+  %lhs1 = vector.extract %lhs[1] : vector<8xi8> from vector<4x8xi8>
+  %lhs2 = vector.extract %lhs[2] : vector<8xi8> from vector<4x8xi8>
+  %lhs3 = vector.extract %lhs[3] : vector<8xi8> from vector<4x8xi8>
+  vector.print %lhs0 : vector<8xi8>
+  vector.print %lhs1 : vector<8xi8>
+  vector.print %lhs2 : vector<8xi8>
+  vector.print %lhs3 : vector<8xi8>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[125, 171, 138, 187, 108, 175,  82,  99],
+                                   [221,  25, 164,  97, 156, 221, 218, 177],
+                                   [171, 160, 219, 191, 144,  45, 161, 210],
+                                   [223, 165, 123,  99, 108,  86,  37,  92]]> : vector<4x8xi8>
+
+  %rhs_m = memref.alloca() : memref<4x8xi8>
+  vector.transfer_write %rhs_cst, %rhs_m[%c0, %c0] : vector<4x8xi8>, memref<4x8xi8>
+
+  %rhs_m1 = memref.collapse_shape %rhs_m [[0, 1]] : memref<4x8xi8> into memref<32xi8>
+  %rhs_flat = vector.transfer_read %rhs_m1[%c0], %c0_i8 {in_bounds = [true]} : memref<32xi8>, vector<[32]xi8>
+
+  vector.print str "RHS:\n"
+  %rhs0 = vector.scalable.extract %rhs_flat[0] : vector<[16]xi8> from vector<[32]xi8>
+  %rhs1 = vector.scalable.extract %rhs_flat[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %rhs0 : vector<[16]xi8>
+  vector.print %rhs1 : vector<[16]xi8>
+
+  %rhs = vector.shape_cast %rhs_flat : vector<[32]xi8> to vector<[4]x8xi8>
+
+  // Matrix multiplication
+  %0 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
+  %1 = arith.extui %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
+  %2 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %0, %1, %acc
+    : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
+
+  // Display the result of the multiplication
+  vector.print str "Result:\n"
+  %u0 = vector.extract %2[0] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u1 = vector.extract %2[1] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u2 = vector.extract %2[2] : vector<[4]xi32> from vector<4x[4]xi32>
+  %u3 = vector.extract %2[3] : vector<[4]xi32> from vector<4x[4]xi32>
+  vector.print %u0 : vector<[4]xi32>
+  vector.print %u1 : vector<[4]xi32>
+  vector.print %u2 : vector<[4]xi32>
+  vector.print %u3 : vector<[4]xi32>
+
+// CHECK: ( -27190, -28812, -30502, -23575 )
+// CHECK: (  -7613,  -8386, -15938,  -6521 )
+// CHECK: (   9468,  18750,   9199,   5764 )
+// CHECK: (  33655,  41064,  48900,  31627 )
+  return
+}
+
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
new file mode 100644
index 0000000000000..7af0b2c3f1054
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/contraction-ummla-4x8x4.mlir
@@ -0,0 +1,119 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s ...
[truncated]

@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-integration branch from 194c1c7 to 87f2964 Compare May 27, 2025 16:22
@momchil-velikov momchil-velikov force-pushed the users/momchil-velikov/vector-contract-i8mm-transform-dialect branch from 251f93e to 4ab7765 Compare May 27, 2025 16:22
@banach-space
Copy link
Contributor

Thanks - great to finally be reaching this stage! I have a few high-level questions and suggestions:

1. Why is the scalable dimension always [4]?

From the current tests, it looks like the scalable dim is always [4]. Could you remind me why that value is chosen?

2. Reduce duplication in the 4x8x4 tests

The current tests differ only in terms of input/output and extsi vs extui. It should be possible to reduce duplication by extracting shared logic into helpers, and writing 4 separate entry points (set via entry_point) to isolate the differences.

For example:

func.func @main_smmla() {
  // Init LHS, RHS, ACC
  // CHECK-LINES for LHS
  print(lhs);
  // CHECK-LINES for RHS
  print(rhs);
 
  arith.extsi (lhs)
  arith.extsi (rhs)
  vector.contract
  
  // CHECK-LINES for ACC
  print(acc);
}

This would keep the test logic focused and easier to maintain.

3. Add checks for generated IR (LLVM dialect)

It would be good to verify that the lowered IR includes the correct SME MMLA intrinsics. For example:

// CHECK-COUNT-4: llvm.intr.smmla

This would help confirm both correctness and that the expected number of operations are emitted.

4. Consider toggling VL within tests
Have you considered toggling the scalable vector length (VL) within the test? That would allow verifying behaviour for multiple VL values.

From what I can tell, this would only work if the inputs are generated inside a loop, similar to this example:

// Allocate memory.
%mem1 = memref.alloca(%svl_s, %svl_s) : memref<?x?xi32>
// Fill each "row" of "mem1" with row number.
//
// For example, assuming an SVL of 128-bits:
//
// 0, 0, 0, 0
// 1, 1, 1, 1
// 2, 2, 2, 2
// 3, 3, 3, 3
//
%init_0 = arith.constant 0 : i32
scf.for %i = %c0 to %svl_s step %c1 iter_args(%val = %init_0) -> (i32) {
%splat_val = vector.broadcast %val : i32 to vector<[4]xi32>
vector.store %splat_val, %mem1[%i, %c0] : memref<?x?xi32>, vector<[4]xi32>
%val_next = arith.addi %val, %c1_i32 : i32
scf.yield %val_next : i32
}

That might be a nice validation of the "scalability" aspect.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants